2702aa
@@ -22,6 +22,7 @@
import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.lang.reflect.Proxy;
 import java.net.URI;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
@@ -101,6 +102,11 @@
public class CachingConnectionFactory extends AbstractConnectionFactory
 
 	private static final int DEFAULT_CHANNEL_CACHE_SIZE = 25;
 
+	private static final Set<String> txStarts = new HashSet<>(Arrays.asList("basicPublish", "basicAck", "basicNack",
+			"basicReject"));
+
+	private static final Set<String> txEnds = new HashSet<>(Arrays.asList("txCommit", "txRollback"));
+
 	private final ChannelCachingConnectionProxy connection = new ChannelCachingConnectionProxy(null);
 
 	public enum CacheMode {
@@ -867,8 +873,6 @@
public class CachingConnectionFactory extends AbstractConnectionFactory
 
 		private final ChannelCachingConnectionProxy theConnection;
 
-		private volatile Channel target;
-
 		private final LinkedList<ChannelProxy> channelList;
 
 		private final String channelListIdentity;
@@ -877,6 +881,10 @@
public class CachingConnectionFactory extends AbstractConnectionFactory
 
 		private final boolean transactional;
 
+		private volatile Channel target;
+
+		private volatile boolean txStarted;
+
 		CachedChannelInvocationHandler(ChannelCachingConnectionProxy connection,
 				Channel target,
 				LinkedList<ChannelProxy> channelList,
@@ -941,13 +949,26 @@
public class CachingConnectionFactory extends AbstractConnectionFactory
 						this.target.close();
 						throw new InvocationTargetException(new AmqpException("PublisherCallbackChannel is closed"));
 					}
+					else if (this.txStarted) {
+						this.txStarted = false;
+						throw new IllegalStateException("Channel closed during transaction");
+					}
 					this.target = null;
 				}
 				synchronized (this.targetMonitor) {
 					if (this.target == null) {
 						this.target = createBareChannel(this.theConnection, this.transactional);
 					}
-					return method.invoke(this.target, args);
+					Object result = method.invoke(this.target, args);
+					if (this.transactional) {
+						if (txStarts.contains(methodName)) {
+							this.txStarted = true;
+						}
+						else if (txEnds.contains(methodName)) {
+							this.txStarted = false;
+						}
+					}
+					return result;
 				}
 			}
 			catch (InvocationTargetException ex) {
